{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "batchSize = 100\n",
    "trainSet = torchvision.datasets.MNIST(root='./data', train = True, transform=transforms.ToTensor(), download=True)\n",
    "trainLoader = torch.utils.data.DataLoader(dataset=trainSet, batch_size=batchSize, shuffle = True)\n",
    "testSet = torchvision.datasets.MNIST(root='./data', train = False, transform=transforms.ToTensor(), download=True)\n",
    "testLoader = torch.utils.data.DataLoader(dataset=testSet, batch_size=batchSize, shuffle = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Model5_2(nn.Module):\n",
    "    def __init__(self):       \n",
    "        self.inSize=28\n",
    "        self.hiddenSize=100\n",
    "        self.numLayers=2\n",
    "        self.outSize = 10\n",
    "        super(Model5_2, self).__init__()\n",
    "        self.lstm = nn.LSTM(self.inSize, self.hiddenSize, self.numLayers, batch_first=True)\n",
    "        self.fc = nn.Linear(self.hiddenSize, self.outSize)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        h0 = torch.zeros(self.numLayers, x.size(0), self.hiddenSize)\n",
    "        c0 = torch.zeros(self.numLayers, x.size(0), self.hiddenSize)\n",
    "        out, (hn, cn) = self.lstm(x, (h0,c0)) \n",
    "        out = self.fc(out[:, -1, :]) \n",
    "        return out\n",
    "\n",
    "model5_2 = Model5_2()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "import time\n",
    "\n",
    "def benchmark(trainLoader,testLoader, model, epochs=1, lr=0.01):\n",
    "    model.__init__()\n",
    "    start=time.time()\n",
    "    optimiser = optim.SGD(model.parameters(), lr=lr)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "    for epoch in range(epochs):\n",
    "        for i, (images, labels) in enumerate(trainLoader):\n",
    "            optimiser.zero_grad()\n",
    "            outputs = model(images.view(-1, 28,28))\n",
    "            loss = criterion(outputs, labels)\n",
    "            loss.backward()\n",
    "            optimiser.step()\n",
    "    print('Accuracy: {0:.4f}'.format(accuracy(testLoader,model)))\n",
    "    print('Training time: {0:.2f}'.format(time.time() - start))\n",
    "    \n",
    "def accuracy(testLoader,model):    \n",
    "    correct, total = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for data in testLoader:\n",
    "            images, labels = data\n",
    "            outputs = model(images.view(-1, 28,28))\n",
    "            _, predicted = torch.max(outputs.data, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()   \n",
    "    return(correct / total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2 Layer LSTM\n",
      "Accuracy: 0.9419\n",
      "Training time: 645.02\n"
     ]
    }
   ],
   "source": [
    "print('2 Layer LSTM')\n",
    "benchmark(trainLoader, testLoader,model5_2, epochs=5,lr=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(list(model5_2.parameters()))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([400, 28])\n",
      "torch.Size([400, 100])\n",
      "torch.Size([400])\n",
      "torch.Size([400])\n",
      "torch.Size([400, 100])\n",
      "torch.Size([400, 100])\n",
      "torch.Size([400])\n",
      "torch.Size([400])\n",
      "torch.Size([10, 100])\n",
      "torch.Size([10])\n"
     ]
    }
   ],
   "source": [
    "for i in range(len(list(model5_2.parameters()))):\n",
    "               print(list(model5_2.parameters())[i].size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
